{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MFEM Example 9\n",
    "\n",
    "Adapted from [PyMFEM/ex9.py](https://github.com/mfem/PyMFEM/blob/master/examples/ex9.py).\n",
    "Compare with the [original Example 9](https://github.com/mfem/mfem/blob/master/examples/ex9.cpp) in MFEM.\n",
    "\n",
    "This example code solves the time-dependent advection equation\n",
    "\n",
    "\\begin{equation*}\n",
    "\\frac{\\partial u}{\\partial t} + v \\cdot \\nabla u = 0,\n",
    "\\label{advection}\\tag{1}\n",
    "\\end{equation*}\n",
    "\n",
    "where $v$ is a given fluid velocity, and\n",
    "\n",
    "\\begin{equation*}\n",
    "u_0(x)=u(0,x)\n",
    "\\label{advection2}\\tag{2}\n",
    "\\end{equation*}\n",
    "\n",
    "is a given initial condition.\n",
    "\n",
    "The example demonstrates the use of Discontinuous Galerkin (DG) bilinear forms in MFEM (face integrators), the use of explicit ODE time integrators, the definition of periodic boundary conditions through periodic meshes, as well as the use of GLVis for persistent visualization of a time-evolving solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "from os.path import expanduser, join\n",
    "import time\n",
    "from urllib.request import urlretrieve\n",
    "\n",
    "import numpy as np\n",
    "from numpy import sqrt, pi, cos, sin, exp, hypot, arctan2\n",
    "from scipy.special import erfc\n",
    "\n",
    "# Requires pymfem and pyglvis see https://github.com/glvis/pyglvis\n",
    "import mfem.ser as mfem\n",
    "from mfem.ser import intArray\n",
    "from glvis import glvis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Setup problem parameters\n",
    "\n",
    "problem = 0           # test case; selects mesh, velocity and initial condition\n",
    "order = 3             # polynomial order of the DG finite element space\n",
    "vis_steps = 5         # refresh the GLVis view every `vis_steps` time steps\n",
    "ode_solver_type = 4   # explicit time integrator, see step 3\n",
    "sleep_interval = .05  # pause (seconds) after each GLVis refresh\n",
    "\n",
    "# Per-problem mesh file, uniform refinement levels, time step and final time.\n",
    "if problem == 0:\n",
    "    # Translating hump\n",
    "    mesh = \"periodic-hexagon.mesh\"\n",
    "    ref_levels = 2\n",
    "    dt = 0.01\n",
    "    t_final = 10\n",
    "elif problem == 1:\n",
    "    # Rotating hump (same initial condition as problem 0, rotating velocity)\n",
    "    mesh = \"periodic-square.mesh\"\n",
    "    ref_levels = 2\n",
    "    dt = 0.005\n",
    "    t_final = 9\n",
    "elif problem == 2:\n",
    "    # Rotating trigs\n",
    "    mesh = \"disc-nurbs.mesh\"\n",
    "    ref_levels = 3\n",
    "    dt = 0.005\n",
    "    t_final = 9\n",
    "elif problem == 3:\n",
    "    # Twisting sines\n",
    "    mesh = \"periodic-square.mesh\"\n",
    "    ref_levels = 4\n",
    "    dt = 0.0025\n",
    "    t_final = 9\n",
    "    vis_steps = 20\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `mesh` in a later cell.\n",
    "    raise ValueError(f\"Unknown problem: {problem} (expected 0, 1, 2 or 3)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. Download the mesh from GitHub and read it.\n",
    "#    We can handle geometrically periodic meshes in this code.\n",
    "\n",
    "mesh_url = f\"https://raw.githubusercontent.com/mfem/mfem/master/data/{mesh}\"\n",
    "# urlretrieve raises on HTTP errors instead of saving an error page as the\n",
    "# mesh file (which would make mfem.Mesh fail with an obscure parse error).\n",
    "urlretrieve(mesh_url, mesh)\n",
    "\n",
    "# NOTE: this rebinds `mesh` from the file name to the Mesh object, so re-run\n",
    "# step 1 before re-running this cell. The two integer arguments are presumably\n",
    "# generate_edges=1, refine=1 of the file-based Mesh constructor -- confirm.\n",
    "mesh = mfem.Mesh(mesh, 1, 1)\n",
    "dim = mesh.Dimension()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Define the ODE solver used for time integration.\n",
    "#    Several explicit Runge-Kutta methods are available.\n",
    "\n",
    "# Factories, so that only the selected solver is constructed.\n",
    "ode_solver_factories = {\n",
    "    1: mfem.ForwardEulerSolver,\n",
    "    2: lambda: mfem.RK2Solver(1.0),\n",
    "    3: mfem.RK3SSolver,\n",
    "    4: mfem.RK4Solver,\n",
    "    6: mfem.RK6Solver,\n",
    "}\n",
    "if ode_solver_type not in ode_solver_factories:\n",
    "    # Fail now instead of leaving ode_solver = None for the time loop.\n",
    "    raise ValueError(f\"Unknown ODE solver type: {ode_solver_type}\")\n",
    "ode_solver = ode_solver_factories[ode_solver_type]()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. Refine the mesh to increase the resolution. In this example we do\n",
    "#    'ref_levels' of uniform refinement. If the mesh is of NURBS type,\n",
    "#    we convert it to a (piecewise-polynomial) high-order mesh.\n",
    "\n",
    "for _ in range(ref_levels):\n",
    "    mesh.UniformRefinement()\n",
    "\n",
    "# As in the C++ ex9: convert a NURBS mesh once, after all refinements.\n",
    "if mesh.NURBSext:\n",
    "    mesh.SetCurvature(max(order, 1))\n",
    "\n",
    "# Bounding box of the final mesh; the coefficients in step 5 use it to map\n",
    "# coordinates to the reference [-1,1] domain. Computing it outside the loop\n",
    "# also keeps bb_min/bb_max defined when ref_levels == 0.\n",
    "bb_min, bb_max = mesh.GetBoundingBox(max(order, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. Define the discontinuous DG finite element space of the given\n",
    "#    polynomial order on the refined mesh.\n",
    "\n",
    "fec = mfem.DG_FECollection(order, dim)\n",
    "fes = mfem.FiniteElementSpace(mesh, fec)\n",
    "\n",
    "print(\"Number of unknowns: \" + str(fes.GetVSize()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficients are defined by subclassing VectorPyCoefficient and\n",
    "# PyCoefficient; a subclass must implement the EvalValue method.\n",
    "class velocity_coeff(mfem.VectorPyCoefficient):\n",
    "    \"\"\"Advection velocity v(x) for the selected `problem`.\n",
    "\n",
    "    Reads the globals `problem`, `bb_min` and `bb_max` when evaluated.\n",
    "    \"\"\"\n",
    "    def EvalValue(self, x):\n",
    "        dim = len(x)\n",
    "        center = (bb_min + bb_max)/2.0\n",
    "        # map to the reference [-1,1] domain\n",
    "        X = 2 * (x - center) / (bb_max - bb_min)\n",
    "        if problem == 0:\n",
    "            # Constant, unit-length velocity\n",
    "            if dim == 1: v = [1.0,]\n",
    "            elif dim == 2: v = [sqrt(2./3.), sqrt(1./3)]\n",
    "            elif dim == 3: v = [sqrt(3./6.), sqrt(2./6), sqrt(1./6.)]\n",
    "        elif problem in (1, 2):\n",
    "            # Clockwise rotation in 2D around the origin\n",
    "            w = pi/2\n",
    "            if dim == 1: v = [1.0,]\n",
    "            elif dim == 2: v = [w*X[1], -w*X[0]]\n",
    "            elif dim == 3: v = [w*X[1], -w*X[0], 0]\n",
    "        elif problem == 3:\n",
    "            # Clockwise twisting rotation in 2D around the origin; the factor\n",
    "            # d is zero on and outside the boundary of the reference square.\n",
    "            w = pi/2\n",
    "            d = max((X[0]+1.)*(1.-X[0]),0.) * max((X[1]+1.)*(1.-X[1]),0.)\n",
    "            d = d ** 2\n",
    "            if dim == 1: v = [1.0,]\n",
    "            elif dim == 2: v = [d*w*X[1], -d*w*X[0]]\n",
    "            elif dim == 3: v = [d*w*X[1], -d*w*X[0], 0]\n",
    "        else:\n",
    "            # Explicit error instead of an UnboundLocalError on `v`.\n",
    "            raise ValueError(f\"Unknown problem: {problem}\")\n",
    "        return v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class u0_coeff(mfem.PyCoefficient):\n",
    "    \"\"\"Initial condition u_0(x) for the selected `problem`.\n",
    "\n",
    "    Reads the globals `problem`, `bb_min` and `bb_max` when evaluated. As in\n",
    "    the C++ ex9, every formula uses the mapped coordinates X.\n",
    "    \"\"\"\n",
    "    def EvalValue(self, x):\n",
    "        dim = len(x)\n",
    "        center = (bb_min + bb_max)/2.0\n",
    "        # map to the reference [-1,1] domain\n",
    "        X = 2 * (x - center) / (bb_max - bb_min)\n",
    "        if problem in (0, 1):\n",
    "            if dim == 1:\n",
    "                return exp(-40. * (X[0]-0.5)**2)\n",
    "            elif dim in (2, 3):\n",
    "                # Smoothed indicator of the box |X[0]-cx| < rx, |X[1]-cy| < ry:\n",
    "                # each erfc factor is ~2 inside the box, hence the /16.\n",
    "                rx = 0.45; ry = 0.25; cx = 0.; cy = -0.2; w = 10.\n",
    "                if dim == 3:\n",
    "                    s = 1. + 0.25*cos(2 * pi * X[2])\n",
    "                    rx = rx * s\n",
    "                    ry = ry * s\n",
    "                return (erfc(w * (X[0]-cx-rx)) * erfc(-w*(X[0]-cx+rx)) *\n",
    "                        erfc(w * (X[1]-cy-ry)) * erfc(-w*(X[1]-cy+ry)))/16\n",
    "        elif problem == 2:\n",
    "            rho = hypot(X[0], X[1])\n",
    "            phi = arctan2(X[1], X[0])\n",
    "            return (sin(pi * rho) ** 2) * sin(3*phi)\n",
    "        elif problem == 3:\n",
    "            return sin(pi * X[0]) * sin(pi * X[1])\n",
    "\n",
    "        return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inflow boundary condition (zero for the problems considered in\n",
    "# this example)\n",
    "class inflow_coeff(mfem.PyCoefficient):\n",
    "    \"\"\"Inflow boundary value; zero everywhere in this example.\"\"\"\n",
    "    def EvalValue(self, x):\n",
    "        return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 6. Set up and assemble the bilinear and linear forms corresponding to\n",
    "#    the DG discretization. The DGTraceIntegrator terms involve integrals\n",
    "#    over interior and boundary mesh faces.\n",
    "\n",
    "velocity = velocity_coeff(dim)\n",
    "inflow = inflow_coeff()\n",
    "u0 = u0_coeff()\n",
    "\n",
    "def dg_trace_integrator():\n",
    "    \"\"\"Return a new transposed DGTraceIntegrator(velocity, 1.0, -0.5).\n",
    "\n",
    "    A fresh object is built per call so that the interior-face and the\n",
    "    boundary-face terms below never share one integrator instance.\n",
    "    \"\"\"\n",
    "    return mfem.TransposeIntegrator(mfem.DGTraceIntegrator(velocity, 1.0, -0.5))\n",
    "\n",
    "# Mass form M\n",
    "m = mfem.BilinearForm(fes)\n",
    "m.AddDomainIntegrator(mfem.MassIntegrator())\n",
    "\n",
    "# Advection form K: volume convection term plus face (trace) terms\n",
    "k = mfem.BilinearForm(fes)\n",
    "k.AddDomainIntegrator(mfem.ConvectionIntegrator(velocity, -1.0))\n",
    "k.AddInteriorFaceIntegrator(dg_trace_integrator())\n",
    "k.AddBdrFaceIntegrator(dg_trace_integrator())\n",
    "\n",
    "# Right-hand side b: boundary flow term driven by the inflow coefficient\n",
    "b = mfem.LinearForm(fes)\n",
    "b.AddBdrFaceIntegrator(mfem.BoundaryFlowIntegrator(inflow, velocity, -1.0, -0.5))\n",
    "\n",
    "skip_zeros = 0\n",
    "m.Assemble()\n",
    "m.Finalize()\n",
    "k.Assemble(skip_zeros)\n",
    "k.Finalize(skip_zeros)\n",
    "b.Assemble()\n",
    "\n",
    "# Initial conditions: project u_0 onto the DG space\n",
    "u = mfem.GridFunction(fes)\n",
    "u.ProjectCoefficient(u0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve and Plot the Solution with GLVis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open a GLVis widget showing the initial condition; the time loop below\n",
    "# updates it through g.update.\n",
    "g = glvis((mesh, u))\n",
    "g.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FE_Evolution(mfem.PyTimeDependentOperator):\n",
    "    \"\"\"Time-dependent operator for the semi-discrete system M du/dt = K u + b.\n",
    "\n",
    "    Mult evaluates du/dt = M^{-1} (K u + b) with a preconditioned CG solve.\n",
    "    \"\"\"\n",
    "    def __init__(self, M, K, b):\n",
    "        mfem.PyTimeDependentOperator.__init__(self, M.Size())\n",
    "        self.K = K\n",
    "        self.M = M\n",
    "        self.b = b\n",
    "        self.z = mfem.Vector(M.Size())  # work vector for K x + b\n",
    "        self.M_prec = mfem.DSmoother()\n",
    "        self.M_solver = mfem.CGSolver()\n",
    "        self.M_solver.SetPreconditioner(self.M_prec)\n",
    "        self.M_solver.SetOperator(M)\n",
    "        # Do not use the incoming y as the initial guess of each solve.\n",
    "        self.M_solver.iterative_mode = False\n",
    "        self.M_solver.SetRelTol(1e-9)\n",
    "        self.M_solver.SetAbsTol(0.0)\n",
    "        self.M_solver.SetMaxIter(100)\n",
    "        self.M_solver.SetPrintLevel(0)\n",
    "\n",
    "    def Mult(self, x, y):\n",
    "        \"\"\"Set y = M^{-1} (K x + b).\"\"\"\n",
    "        self.K.Mult(x, self.z)\n",
    "        # Use the b this operator was built with, not the notebook global.\n",
    "        self.z += self.b\n",
    "        self.M_solver.Mult(self.z, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adv = FE_Evolution(m.SpMat(), k.SpMat(), b)\n",
    "\n",
    "ode_solver.Init(adv)\n",
    "t = 0.0\n",
    "ti = 0\n",
    "time.sleep(1)  # presumably lets the GLVis widget finish its first render\n",
    "while t <= t_final - dt/2:\n",
    "    t, dt = ode_solver.Step(u, t, dt)\n",
    "    ti += 1\n",
    "    last_step = t > t_final - dt/2\n",
    "    # Always refresh after the last step so the final state is displayed.\n",
    "    if ti % vis_steps == 0 or last_step:\n",
    "        g.update((mesh, u))\n",
    "        time.sleep(sleep_interval)\n",
    "        print(f\"time step: {ti}, time: {t:.2e}\", end=\"\\r\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
